from abc import ABC
from datetime import datetime, time
from typing import Optional, Iterable

import pandas as pd
from rqalpha.const import INSTRUMENT_TYPE
from rqalpha.data.base_data_source import BaseDataSource
from rqalpha.interface import AbstractMod
from rqalpha.model.instrument import Instrument
from quant.vlogger import VLogger

__config__ = {
    'bars': None,
    'instruments': None,
}


# This module is needed because functions like get_bar are called by trading
# APIs (such as order_target_percent). If the data is only used by the
# strategy, it can be read directly from within the strategy.
class LocalDataSourceMod(AbstractMod):
    """Mod that installs a LocalDataSource as the environment's data source."""

    def __init__(self):
        """Nothing to set up here; all wiring happens in start_up."""

    def start_up(self, env, mod_config):
        """Build a LocalDataSource from the mod config and register it on env.

        The bundle path comes from ``env.config.base.data_bundle_path``; the
        bars and instruments tables come from ``mod_config``.
        """
        bundle_path = env.config.base.data_bundle_path
        data_source = LocalDataSource(bundle_path, mod_config.bars,
                                      mod_config.instruments)
        env.set_data_source(data_source)

    def tear_down(self, code, exception=None):
        """No resources to release."""
        return None


class LocalDataSource(BaseDataSource, ABC):
    """Data source answering daily-bar and instrument queries from memory.

    Any request for an instrument that is not in the local instruments table,
    or for a frequency other than '1d', is delegated to ``BaseDataSource``.

    Assumption on bars
      index: (datetime, order_book_id) -- datetime is the FIRST level, which
        is how get_bar / history_bars slice it; one level must be named
        'datetime' (used by available_data_range)
      columns: open, close, low, high, volume
      datetime is in datetime format, not string
    Assumption on instruments
      columns: listed_date (order_book_id is filled in from the index)
      index: order_book_id
    """

    # get_bar looks up bars dated strictly after this day under
    # 'SHSE.xxxxxx' / 'SZSE.xxxxxx' ids instead of '*.XSHG' / '*.XSHE'.
    # NOTE(review): presumably the bar data switches id format on that
    # date -- confirm against the data provider.
    _ID_FORMAT_SWITCH_DATE = datetime(2022, 4, 13).date()

    def __init__(self, path, bars, instruments):
        """Wrap the given tables.

        Args:
            path: data bundle path, forwarded to ``BaseDataSource``.
            bars: DataFrame of daily bars (see class docstring).
            instruments: DataFrame of instrument attributes, one row per
                order_book_id; each row is turned into an ``Instrument``.

        Raises:
            AssertionError: if ``bars`` or ``instruments`` is None.
            Exception: whatever ``Instrument(...)`` raises for a bad row
                (logged, then re-raised).
        """
        super().__init__(path, None)
        assert bars is not None, 'LocalDataSource requires bars'
        assert instruments is not None, 'LocalDataSource requires instruments'
        self._bars = bars
        self._instruments = {}
        for key, ins in instruments.to_dict(orient='index').items():
            # The id lives in the index, so copy it into the row dict.
            ins['order_book_id'] = key
            try:
                self._instruments[key] = Instrument(ins)
            except Exception as e:
                # Log the offending row through the module's logger (same
                # channel as the other errors in this class), then propagate.
                VLogger.vlog(
                    0, 'Mod error: invalid instrument %s: %s' % (ins, e))
                raise

    def get_bar(self, instrument, dt, frequency):
        """Return the daily bar of ``instrument`` at ``dt`` as a dict.

        The dict holds the index levels and every bar column of the matching
        row. Non-'1d' frequencies and unknown instruments are delegated to
        the base class.

        Raises:
            KeyError: if the lookup fails (e.g. ``dt`` is not in the index).
            Exception: if ``dt`` is in the index but has no row for the
                instrument.
        """
        if (frequency != '1d'
                or instrument.order_book_id not in self._instruments):
            return super().get_bar(instrument, dt, frequency)

        # After trading will invoke get_bar with time 15:30:00
        order_book_id = instrument.order_book_id
        if dt.date() > self._ID_FORMAT_SWITCH_DATE:
            # Translate exchange-suffixed ids into the prefixed form.
            if order_book_id.endswith('XSHG'):
                order_book_id = 'SHSE.%s' % order_book_id[:6]
            elif order_book_id.endswith('XSHE'):
                order_book_id = 'SZSE.%s' % order_book_id[:6]

        try:
            # Scalar dt on level 0, one-element slice on level 1: the result
            # stays a DataFrame (possibly empty) instead of a Series.
            bars = self._bars.loc(
                axis=0)[pd.IndexSlice[dt, order_book_id:order_book_id]]
        except KeyError as e:
            VLogger.vlog(0, 'Mod error: KeyError %s' % e)
            raise

        if bars.empty:
            # dt in index, but no order_book_id?
            VLogger.vlog(
                0, 'Mod error: Empty bars (%s, %s)' % (dt, order_book_id))
            raise Exception('%s, %s' % (dt, order_book_id))
        return bars.reset_index().iloc[0].to_dict()

    def get_instruments(self, id_or_syms=None, types=None):
        # type: (Optional[Iterable[str]], Optional[Iterable[INSTRUMENT_TYPE]]) -> Iterable[Instrument]
        """Return instruments, preferring the local table over the bundle.

        With ``id_or_syms`` None: if ``types`` contains 'CONVERTIBLE', lazily
        yield the local instruments whose ``bond_type`` is 'cb'; otherwise
        delegate to the base class.

        With ``id_or_syms`` given: return a list with the locally known
        instruments first (in input order, duplicates removed), followed by
        whatever the base class returns for the remaining ids. ``types`` is
        not consulted in this case.
        """
        if id_or_syms is None:
            if types is not None and 'CONVERTIBLE' in types:
                return (ins for ins in self._instruments.values()
                        if ins.bond_type == 'cb')
            return super().get_instruments(id_or_syms=id_or_syms,
                                           types=types)

        # Materialize exactly once: id_or_syms may be a one-shot iterator,
        # and dict.fromkeys de-duplicates while keeping the input order so
        # the result is deterministic.
        requested = list(dict.fromkeys(id_or_syms))
        results = [
            self._instruments[i] for i in requested if i in self._instruments
        ]
        missing = [i for i in requested if i not in self._instruments]
        if missing:
            results.extend(super().get_instruments(missing))
        return results

    def history_bars(self,
                     instrument,
                     bar_count,
                     frequency,
                     fields,
                     dt,
                     skip_suspended=True,
                     include_now=False,
                     adjust_type='pre',
                     adjust_orig=None):
        """Return up to ``bar_count`` daily bars ending at ``dt`` as a dict.

        The result is ``DataFrame.to_dict()`` of the selected rows after
        ``reset_index()``, i.e. a mapping of column name to
        {row position: value}. ``fields`` None selects every bar column.

        On the local path ``skip_suspended``, ``include_now``,
        ``adjust_type`` and ``adjust_orig`` are not used. Non-'1d'
        frequencies and unknown instruments are delegated to the base class
        with all arguments.

        NOTE(review): unlike get_bar, no 'SHSE.'/'SZSE.' id translation is
        applied here -- confirm whether that is intended.

        Raises:
            KeyError: if the selection fails (logged, then re-raised).
        """
        if (frequency != '1d'
                or instrument.order_book_id not in self._instruments):
            return super().history_bars(instrument,
                                        bar_count,
                                        frequency,
                                        fields,
                                        dt,
                                        skip_suspended=skip_suspended,
                                        include_now=include_now,
                                        adjust_type=adjust_type,
                                        adjust_orig=adjust_orig)

        fields = self._bars.columns if fields is None else fields
        order_book_id = instrument.order_book_id
        try:
            selected = self._bars[fields].loc(
                axis=0)[pd.IndexSlice[:dt, order_book_id:order_book_id]]
            return selected.tail(bar_count).reset_index().to_dict()
        except KeyError as e:
            VLogger.vlog(
                0,
                'Mod error: %s, %s, %s' % (e, instrument.de_listed_date, dt))
            raise

    def is_suspended(self, order_book_id, dates):
        """Return one suspension flag per entry of ``dates``.

        A date counts as suspended when its bar has zero volume or when the
        bar lookup raises KeyError. Unknown instruments are delegated to the
        base class.
        """
        if order_book_id not in self._instruments:
            return super().is_suspended(order_book_id, dates)

        instrument = self._instruments[order_book_id]
        suspended = []
        for dt in dates:
            try:
                bar = self.get_bar(instrument,
                                   datetime.combine(dt, time(15, 0, 0)), '1d')
                # Kept inside the try on purpose: a bar without a 'volume'
                # key is treated like a missing bar.
                suspended.append(bar['volume'] == 0)
            except KeyError:
                suspended.append(True)
        return suspended

    def available_data_range(self, frequency):
        """Return (first date, last date) covered by the local bars.

        The range is read from the 'datetime' index level and is the same
        for both accepted frequencies.

        Raises:
            AssertionError: if ``frequency`` is neither '1d' nor '1m'.
        """
        assert frequency in ['1d', '1m'], (
            'unsupported frequency: %s' % frequency)
        index = self._bars.index
        datetimes = index.get_level_values(index.names.index('datetime'))
        return datetimes.min().date(), datetimes.max().date()